import argparse
import torch as th

import dgl
from dgl.dataloading import GraphDataLoader
import warnings

from dataset import load
warnings.filterwarnings('ignore')

from utils import linearsvc
from model import MVGRL

# Command-line options for the MVGRL script.
parser = argparse.ArgumentParser(description='mvgrl')

# Dataset and hardware selection.
parser.add_argument(
    '--dataname', type=str, default='MUTAG',
    help='Name of dataset.')
parser.add_argument(
    '--gpu', type=int, default=-1,
    help='GPU index. Default: -1, using cpu.')

# Optimisation schedule.
parser.add_argument(
    '--epochs', type=int, default=200,
    help=' Number of training periods.')
parser.add_argument(
    '--patience', type=int, default=20,
    help='Early stopping steps.')
parser.add_argument(
    '--lr', type=float, default=0.001,
    help='Learning rate of mvgrl.')
parser.add_argument(
    '--wd', type=float, default=0.,
    help='Weight decay of mvgrl.')
parser.add_argument(
    '--batch_size', type=int, default=64,
    help='Batch size.')

# Model size.
parser.add_argument(
    '--n_layers', type=int, default=4,
    help='Number of GNN layers.')
parser.add_argument(
    "--hid_dim", type=int, default=32,
    help='Hidden layer dim.')

args = parser.parse_args()

# Use the requested GPU only when an index other than -1 was given and CUDA
# is actually available; in every other case run on the CPU.
args.device = (
    f'cuda:{args.gpu}'
    if args.gpu != -1 and th.cuda.is_available()
    else 'cpu'
)


def collate(samples):
    """Merge a list of dataset samples into one training mini-batch.

    Each sample is unpacked as a ``(graph, diff_graph, label)`` triple. The
    graphs and the diff graphs are each combined with ``dgl.batch``. Every
    node of the batched graph is additionally tagged, under
    ``ndata['graph_id']``, with the position of its source graph within
    ``samples``.

    Returns
    -------
    tuple
        ``(batched_graph, batched_diff_graph, batched_labels)`` where
        ``batched_labels`` is ``th.tensor(labels)``.
    """
    graphs, diff_graphs, labels = (list(column) for column in zip(*samples))

    merged = dgl.batch(graphs)
    merged_diff = dgl.batch(diff_graphs)

    # One id per graph in the batch, expanded by DGL to one value per node so
    # that each node carries the id of the graph it belongs to.
    per_graph_ids = th.arange(len(graphs))
    merged.ndata['graph_id'] = dgl.broadcast_nodes(merged, per_graph_ids)

    return merged, merged_diff, th.tensor(labels)

if __name__ == '__main__':
    # Script entry point: load the dataset, evaluate the untrained model,
    # train with early stopping on the epoch loss, then evaluate the best
    # checkpoint with `linearsvc`.

    # Step 1: Prepare data =================================================================== #
    dataset = load(args.dataname)

    # Every dataset item unpacks as (graph, diff_graph, label).
    graphs, diff_graphs, labels = map(list, zip(*dataset))
    print('Number of graphs:', len(graphs))

    # Batch all examples into a single graph pair; it is used for the
    # embedding-based evaluation before and after training.
    wholegraph = dgl.batch(graphs)
    whole_dg = dgl.batch(diff_graphs)

    # create dataloader for batch training
    dataloader = GraphDataLoader(dataset,
                                 batch_size=args.batch_size,
                                 collate_fn=collate,
                                 drop_last=False,
                                 shuffle=True)

    in_dim = wholegraph.ndata['feat'].shape[1]

    # Step 2: Create model =================================================================== #
    model = MVGRL(in_dim, args.hid_dim, args.n_layers)
    model = model.to(args.device)

    # Step 3: Create training components ===================================================== #
    # Pass --wd through to the optimizer; it was previously parsed but ignored.
    optimizer = th.optim.Adam(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.wd)

    print('===== Before training ======')

    wholegraph = wholegraph.to(args.device)
    whole_dg = whole_dg.to(args.device)
    wholefeat = wholegraph.ndata.pop('feat')
    whole_weight = whole_dg.edata.pop('edge_weight')

    embs = model.get_embedding(wholegraph, whole_dg, wholefeat, whole_weight)
    lbls = th.LongTensor(labels)
    acc_mean, acc_std = linearsvc(embs, lbls)
    print('accuracy_mean, {:.4f}'.format(acc_mean))

    best = float('inf')
    cnt_wait = 0
    checkpoint_path = f'{args.dataname}.pkl'
    # Tracks whether THIS run wrote a checkpoint, so that evaluation never
    # loads a missing file or a stale checkpoint left by an earlier run.
    checkpoint_saved = False

    # Step 4: Training epochs =============================================================== #
    for epoch in range(args.epochs):
        loss_all = 0
        model.train()

        for graph, diff_graph, _ in dataloader:
            graph = graph.to(args.device)
            diff_graph = diff_graph.to(args.device)

            feat = graph.ndata['feat']
            graph_id = graph.ndata['graph_id']
            edge_weight = diff_graph.edata['edge_weight']

            optimizer.zero_grad()
            loss = model(graph, diff_graph, feat, edge_weight, graph_id)
            loss_all += loss.item()
            loss.backward()
            optimizer.step()

        print('Epoch {}, Loss {:.4f}'.format(epoch, loss_all))

        # Early stopping is driven by the accumulated epoch loss (a plain
        # float), not by the loss tensor of whichever shuffled mini-batch
        # happened to come last.
        if loss_all < best:
            best = loss_all
            cnt_wait = 0
            th.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print('Early stopping')
            break

    print('Training End')

    # Step 5:  Linear evaluation ========================================================== #
    # Restore the best weights saved during this run; if nothing was saved
    # (e.g. --epochs 0), evaluate the model as it currently is.
    if checkpoint_saved:
        model.load_state_dict(th.load(checkpoint_path))
    embs = model.get_embedding(wholegraph, whole_dg, wholefeat, whole_weight)

    acc_mean, acc_std = linearsvc(embs, lbls)
    print('accuracy_mean, {:.4f}'.format(acc_mean))